{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LtL6pxC3QK6k"
   },
   "source": [
    "# Question answering (PyTorch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6BeDhSYwQK6m"
   },
   "source": [
    "Install the 🤗 *Datasets* and 🤗 *Transformers* libraries to run this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "aFM3SNcQQK6o"
   },
   "outputs": [],
   "source": [
    "!pip install datasets evaluate transformers[sentencepiece]\n",
    "!pip install accelerate\n",
    "# To run the training on TPU, you will need to uncomment the following line:\n",
    "# !pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl\n",
    "!apt install git-lfs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "scOm799eQK6p"
   },
   "source": [
    "You will need to set up git: adapt your email and name in the following cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "uBSqsaB2QK6q"
   },
   "outputs": [],
   "source": [
    "!git config --global user.email \"you@example.com\"\n",
    "!git config --global user.name \"Your Name\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cUwgvueJQK6r"
   },
   "source": [
    "You will also need to be logged in to the Hugging Face Hub. Run the following cell and enter your credentials."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "IjUdvDURQK6r"
   },
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4y0INx16_YJI"
   },
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "raw_datasets = load_dataset(\"piaf\")\n",
    "\n",
    "# PIAF does not provide a test split, so we create one\n",
    "raw_datasets = raw_datasets['train']\n",
    "raw_datasets = raw_datasets.train_test_split(test_size=0.2, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "j8x478MCQK6t"
   },
   "outputs": [],
   "source": [
    "raw_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bvADO3btQK6u"
   },
   "outputs": [],
   "source": [
    "print(\"Context: \", raw_datasets[\"train\"][0][\"context\"])\n",
    "print(\"Question: \", raw_datasets[\"train\"][0][\"question\"])\n",
    "print(\"Answer: \", raw_datasets[\"train\"][0][\"answers\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LkgfLJzKAyL8"
   },
   "outputs": [],
   "source": [
    "raw_datasets[\"train\"].filter(lambda x: len(x[\"answers\"][\"text\"]) != 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "sSVNaNijQK6w"
   },
   "outputs": [],
   "source": [
    "print(raw_datasets[\"test\"][0][\"answers\"])\n",
    "print(raw_datasets[\"test\"][2][\"answers\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tqovc8IQQK6w"
   },
   "outputs": [],
   "source": [
    "print(raw_datasets[\"test\"][2][\"context\"])\n",
    "print(raw_datasets[\"test\"][2][\"question\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "1vIQVMVbQK6x"
   },
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "model_checkpoint = \"camembert-base\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "85KnO70zQK6x"
   },
   "outputs": [],
   "source": [
    "tokenizer.is_fast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "MRNaEIRUQK6y"
   },
   "outputs": [],
   "source": [
    "context = raw_datasets[\"train\"][0][\"context\"]\n",
    "question = raw_datasets[\"train\"][0][\"question\"]\n",
    "\n",
    "inputs = tokenizer(question, context)\n",
    "tokenizer.decode(inputs[\"input_ids\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "P7-ROeq8QK6y"
   },
   "outputs": [],
   "source": [
    "inputs = tokenizer(\n",
    "    question,\n",
    "    context,\n",
    "    max_length=100,\n",
    "    truncation=\"only_second\",\n",
    "    stride=50,\n",
    "    return_overflowing_tokens=True,\n",
    ")\n",
    "\n",
    "for ids in inputs[\"input_ids\"]:\n",
    "    print(tokenizer.decode(ids))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5a-69XqDQK6z"
   },
   "outputs": [],
   "source": [
    "inputs = tokenizer(\n",
    "    question,\n",
    "    context,\n",
    "    max_length=100,\n",
    "    truncation=\"only_second\",\n",
    "    stride=50,\n",
    "    return_overflowing_tokens=True,\n",
    "    return_offsets_mapping=True,\n",
    ")\n",
    "inputs.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "qEIj9Wg4QK6z"
   },
   "outputs": [],
   "source": [
    "inputs[\"overflow_to_sample_mapping\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JUami5uPQK60"
   },
   "outputs": [],
   "source": [
    "inputs = tokenizer(\n",
    "    raw_datasets[\"train\"][2:6][\"question\"],\n",
    "    raw_datasets[\"train\"][2:6][\"context\"],\n",
    "    max_length=100,\n",
    "    truncation=\"only_second\",\n",
    "    stride=50,\n",
    "    return_overflowing_tokens=True,\n",
    "    return_offsets_mapping=True,\n",
    ")\n",
    "\n",
    "print(f\"The 4 examples gave {len(inputs['input_ids'])} features.\")\n",
    "print(f\"Here is where each comes from: {inputs['overflow_to_sample_mapping']}.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JsGGWylQQK60"
   },
   "outputs": [],
   "source": [
    "answers = raw_datasets[\"train\"][2:6][\"answers\"]\n",
    "start_positions = []\n",
    "end_positions = []\n",
    "\n",
    "for i, offset in enumerate(inputs[\"offset_mapping\"]):\n",
    "    sample_idx = inputs[\"overflow_to_sample_mapping\"][i]\n",
    "    answer = answers[sample_idx]\n",
    "    start_char = answer[\"answer_start\"][0]\n",
    "    end_char = answer[\"answer_start\"][0] + len(answer[\"text\"][0])\n",
    "    sequence_ids = inputs.sequence_ids(i)\n",
    "\n",
    "    # Find the start and end of the context\n",
    "    idx = 0\n",
    "    while sequence_ids[idx] != 1:\n",
    "        idx += 1\n",
    "    context_start = idx\n",
    "    while sequence_ids[idx] == 1:\n",
    "        idx += 1\n",
    "    context_end = idx - 1\n",
    "\n",
    "    # If the answer is not fully inside the context, the label is (0, 0)\n",
    "    if offset[context_start][0] > start_char or offset[context_end][1] < end_char:\n",
    "        start_positions.append(0)\n",
    "        end_positions.append(0)\n",
    "    else:\n",
    "        # Otherwise, the labels are the start and end token positions\n",
    "        idx = context_start\n",
    "        while idx <= context_end and offset[idx][0] <= start_char:\n",
    "            idx += 1\n",
    "        start_positions.append(idx - 1)\n",
    "\n",
    "        idx = context_end\n",
    "        while idx >= context_start and offset[idx][1] >= end_char:\n",
    "            idx -= 1\n",
    "        end_positions.append(idx + 1)\n",
    "\n",
    "start_positions, end_positions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "T-2yLtjlQK61"
   },
   "outputs": [],
   "source": [
    "idx = 0\n",
    "sample_idx = inputs[\"overflow_to_sample_mapping\"][idx]\n",
    "answer = answers[sample_idx][\"text\"][0]\n",
    "\n",
    "start = start_positions[idx]\n",
    "end = end_positions[idx]\n",
    "labeled_answer = tokenizer.decode(inputs[\"input_ids\"][idx][start : end + 1])\n",
    "\n",
    "print(f\"Theoretical answer: {answer}, labels give: {labeled_answer}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fAgj10qeQK62"
   },
   "outputs": [],
   "source": [
    "idx = 4\n",
    "sample_idx = inputs[\"overflow_to_sample_mapping\"][idx]\n",
    "answer = answers[sample_idx][\"text\"][0]\n",
    "\n",
    "decoded_example = tokenizer.decode(inputs[\"input_ids\"][idx])\n",
    "print(f\"Theoretical answer: {answer}, decoded example: {decoded_example}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "WNhoHB2CQK62"
   },
   "outputs": [],
   "source": [
    "max_length = 384\n",
    "stride = 128\n",
    "\n",
    "\n",
    "def preprocess_training_examples(examples):\n",
    "    questions = [q.strip() for q in examples[\"question\"]]\n",
    "    inputs = tokenizer(\n",
    "        questions,\n",
    "        examples[\"context\"],\n",
    "        max_length=max_length,\n",
    "        truncation=\"only_second\",\n",
    "        stride=stride,\n",
    "        return_overflowing_tokens=True,\n",
    "        return_offsets_mapping=True,\n",
    "        padding=\"max_length\",\n",
    "    )\n",
    "\n",
    "    offset_mapping = inputs.pop(\"offset_mapping\")\n",
    "    sample_map = inputs.pop(\"overflow_to_sample_mapping\")\n",
    "    answers = examples[\"answers\"]\n",
    "    start_positions = []\n",
    "    end_positions = []\n",
    "\n",
    "    for i, offset in enumerate(offset_mapping):\n",
    "        sample_idx = sample_map[i]\n",
    "        answer = answers[sample_idx]\n",
    "        start_char = answer[\"answer_start\"][0]\n",
    "        end_char = answer[\"answer_start\"][0] + len(answer[\"text\"][0])\n",
    "        sequence_ids = inputs.sequence_ids(i)\n",
    "\n",
    "        # Find the start and end of the context\n",
    "        idx = 0\n",
    "        while sequence_ids[idx] != 1:\n",
    "            idx += 1\n",
    "        context_start = idx\n",
    "        while sequence_ids[idx] == 1:\n",
    "            idx += 1\n",
    "        context_end = idx - 1\n",
    "\n",
    "        # If the answer is not fully inside the context, the label is (0, 0)\n",
    "        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:\n",
    "            start_positions.append(0)\n",
    "            end_positions.append(0)\n",
    "        else:\n",
    "            # Otherwise, the labels are the start and end token positions\n",
    "            idx = context_start\n",
    "            while idx <= context_end and offset[idx][0] <= start_char:\n",
    "                idx += 1\n",
    "            start_positions.append(idx - 1)\n",
    "\n",
    "            idx = context_end\n",
    "            while idx >= context_start and offset[idx][1] >= end_char:\n",
    "                idx -= 1\n",
    "            end_positions.append(idx + 1)\n",
    "\n",
    "    inputs[\"start_positions\"] = start_positions\n",
    "    inputs[\"end_positions\"] = end_positions\n",
    "    return inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "k4JdhH85QK63"
   },
   "outputs": [],
   "source": [
    "train_dataset = raw_datasets[\"train\"].map(\n",
    "    preprocess_training_examples,\n",
    "    batched=True,\n",
    "    remove_columns=raw_datasets[\"train\"].column_names,\n",
    ")\n",
    "len(raw_datasets[\"train\"]), len(train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "idi_R9oBQK63"
   },
   "outputs": [],
   "source": [
    "def preprocess_validation_examples(examples):\n",
    "    questions = [q.strip() for q in examples[\"question\"]]\n",
    "    inputs = tokenizer(\n",
    "        questions,\n",
    "        examples[\"context\"],\n",
    "        max_length=max_length,\n",
    "        truncation=\"only_second\",\n",
    "        stride=stride,\n",
    "        return_overflowing_tokens=True,\n",
    "        return_offsets_mapping=True,\n",
    "        padding=\"max_length\",\n",
    "    )\n",
    "\n",
    "    sample_map = inputs.pop(\"overflow_to_sample_mapping\")\n",
    "    example_ids = []\n",
    "\n",
    "    for i in range(len(inputs[\"input_ids\"])):\n",
    "        sample_idx = sample_map[i]\n",
    "        example_ids.append(examples[\"id\"][sample_idx])\n",
    "\n",
    "        sequence_ids = inputs.sequence_ids(i)\n",
    "        offset = inputs[\"offset_mapping\"][i]\n",
    "        inputs[\"offset_mapping\"][i] = [\n",
    "            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)\n",
    "        ]\n",
    "\n",
    "    inputs[\"example_id\"] = example_ids\n",
    "    return inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "onDfkeWPQK64"
   },
   "outputs": [],
   "source": [
    "validation_dataset = raw_datasets[\"test\"].map(\n",
    "    preprocess_validation_examples,\n",
    "    batched=True,\n",
    "    remove_columns=raw_datasets[\"test\"].column_names,\n",
    ")\n",
    "len(raw_datasets[\"test\"]), len(validation_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "K68F9QgOQK64"
   },
   "outputs": [],
   "source": [
    "small_eval_set = raw_datasets[\"test\"].select(range(100))\n",
    "trained_checkpoint = \"distilbert-base-cased-distilled-squad\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(trained_checkpoint)\n",
    "eval_set = small_eval_set.map(\n",
    "    preprocess_validation_examples,\n",
    "    batched=True,\n",
    "    remove_columns=raw_datasets[\"test\"].column_names,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6rKHB7LNQK64"
   },
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Qzm-iQN0QK65"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoModelForQuestionAnswering\n",
    "\n",
    "eval_set_for_model = eval_set.remove_columns([\"example_id\", \"offset_mapping\"])\n",
    "eval_set_for_model.set_format(\"torch\")\n",
    "\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "batch = {k: eval_set_for_model[k].to(device) for k in eval_set_for_model.column_names}\n",
    "trained_model = AutoModelForQuestionAnswering.from_pretrained(trained_checkpoint).to(\n",
    "    device\n",
    ")\n",
    "\n",
    "with torch.no_grad():\n",
    "    outputs = trained_model(**batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bnfZPG6SQK65"
   },
   "outputs": [],
   "source": [
    "start_logits = outputs.start_logits.cpu().numpy()\n",
    "end_logits = outputs.end_logits.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8LFCmyBBQK65"
   },
   "outputs": [],
   "source": [
    "import collections\n",
    "\n",
    "example_to_features = collections.defaultdict(list)\n",
    "for idx, feature in enumerate(eval_set):\n",
    "    example_to_features[feature[\"example_id\"]].append(idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6R-G01nKQK65"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "n_best = 20\n",
    "max_answer_length = 30\n",
    "predicted_answers = []\n",
    "\n",
    "for example in small_eval_set:\n",
    "    example_id = example[\"id\"]\n",
    "    context = example[\"context\"]\n",
    "    answers = []\n",
    "\n",
    "    for feature_index in example_to_features[example_id]:\n",
    "        start_logit = start_logits[feature_index]\n",
    "        end_logit = end_logits[feature_index]\n",
    "        offsets = eval_set[\"offset_mapping\"][feature_index]\n",
    "\n",
    "        start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()\n",
    "        end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()\n",
    "        for start_index in start_indexes:\n",
    "            for end_index in end_indexes:\n",
    "                # Skip answers that are not fully inside the context\n",
    "                if offsets[start_index] is None or offsets[end_index] is None:\n",
    "                    continue\n",
    "                # Skip answers whose length is either < 0 or > max_answer_length\n",
    "                if (\n",
    "                    end_index < start_index\n",
    "                    or end_index - start_index + 1 > max_answer_length\n",
    "                ):\n",
    "                    continue\n",
    "\n",
    "                answers.append(\n",
    "                    {\n",
    "                        \"text\": context[offsets[start_index][0] : offsets[end_index][1]],\n",
    "                        \"logit_score\": start_logit[start_index] + end_logit[end_index],\n",
    "                    }\n",
    "                )\n",
    "\n",
    "    best_answer = max(answers, key=lambda x: x[\"logit_score\"])\n",
    "    predicted_answers.append({\"id\": example_id, \"prediction_text\": best_answer[\"text\"]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "d6iQFh4hQK66"
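   },
   "outputs": [],
   "source": [
    "# `load_metric` is no longer available in recent 🤗 Datasets releases;\n",
    "# the SQuAD metric is loaded from 🤗 Evaluate instead (`pip install evaluate`)\n",
    "import evaluate\n",
    "\n",
    "metric = evaluate.load(\"squad\")"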
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "iE6yr01MQK66"
   },
   "outputs": [],
   "source": [
    "theoretical_answers = [\n",
    "    {\"id\": ex[\"id\"], \"answers\": ex[\"answers\"]} for ex in small_eval_set\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "TAL68eNyQK66"
   },
   "outputs": [],
   "source": [
    "print(predicted_answers[0])\n",
    "print(theoretical_answers[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "SVa0WCW-QK66"
   },
   "outputs": [],
   "source": [
    "metric.compute(predictions=predicted_answers, references=theoretical_answers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "hdPf3xs9QK67"
   },
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "\n",
    "def compute_metrics(start_logits, end_logits, features, examples):\n",
    "    example_to_features = collections.defaultdict(list)\n",
    "    for idx, feature in enumerate(features):\n",
    "        example_to_features[feature[\"example_id\"]].append(idx)\n",
    "\n",
    "    predicted_answers = []\n",
    "    for example in tqdm(examples):\n",
    "        example_id = example[\"id\"]\n",
    "        context = example[\"context\"]\n",
    "        answers = []\n",
    "\n",
    "        # Loop through all features associated with this example\n",
    "        for feature_index in example_to_features[example_id]:\n",
    "            start_logit = start_logits[feature_index]\n",
    "            end_logit = end_logits[feature_index]\n",
    "            offsets = features[feature_index][\"offset_mapping\"]\n",
    "\n",
    "            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()\n",
    "            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()\n",
    "            for start_index in start_indexes:\n",
    "                for end_index in end_indexes:\n",
    "                    # Skip answers that are not fully inside the context\n",
    "                    if offsets[start_index] is None or offsets[end_index] is None:\n",
    "                        continue\n",
    "                    # Skip answers whose length is either < 0 or > max_answer_length\n",
    "                    if (\n",
    "                        end_index < start_index\n",
    "                        or end_index - start_index + 1 > max_answer_length\n",
    "                    ):\n",
    "                        continue\n",
    "\n",
    "                    answer = {\n",
    "                        \"text\": context[offsets[start_index][0] : offsets[end_index][1]],\n",
    "                        \"logit_score\": start_logit[start_index] + end_logit[end_index],\n",
    "                    }\n",
    "                    answers.append(answer)\n",
    "\n",
    "        # Select the answer with the best score\n",
    "        if len(answers) > 0:\n",
    "            best_answer = max(answers, key=lambda x: x[\"logit_score\"])\n",
    "            predicted_answers.append(\n",
    "                {\"id\": example_id, \"prediction_text\": best_answer[\"text\"]}\n",
    "            )\n",
    "        else:\n",
    "            predicted_answers.append({\"id\": example_id, \"prediction_text\": \"\"})\n",
    "\n",
    "    theoretical_answers = [{\"id\": ex[\"id\"], \"answers\": ex[\"answers\"]} for ex in examples]\n",
    "    return metric.compute(predictions=predicted_answers, references=theoretical_answers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bKYqSj6LQK67"
   },
   "outputs": [],
   "source": [
    "compute_metrics(start_logits, end_logits, eval_set, small_eval_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VbQQgkhsQK67"
   },
   "outputs": [],
   "source": [
    "model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "pJqyc7vrQK68"
   },
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments\n",
    "\n",
    "args = TrainingArguments(\n",
    "    \"camembert-base-finetuned-piaf\",\n",
    "    # No evaluation during training: this is the default strategy, so the argument is left out\n",
    "    save_strategy=\"epoch\",\n",
    "    learning_rate=2e-5,\n",
    "    num_train_epochs=3,\n",
    "    weight_decay=0.01,\n",
    "    fp16=True,\n",
    "    push_to_hub=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5lsISGbRQK68"
   },
   "outputs": [],
   "source": [
    "from transformers import Trainer\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=validation_dataset,\n",
    "    tokenizer=tokenizer,\n",
    ")\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BZCgc7sJQK68"
   },
   "outputs": [],
   "source": [
    "predictions, _, _ = trainer.predict(validation_dataset)\n",
    "start_logits, end_logits = predictions\n",
    "compute_metrics(start_logits, end_logits, validation_dataset, raw_datasets[\"test\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "OzXfIms3QK69"
   },
   "outputs": [],
   "source": [
    "trainer.push_to_hub(commit_message=\"Training complete\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "coNkcqleQK69"
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "from transformers import default_data_collator\n",
    "\n",
    "train_dataset.set_format(\"torch\")\n",
    "validation_set = validation_dataset.remove_columns([\"example_id\", \"offset_mapping\"])\n",
    "validation_set.set_format(\"torch\")\n",
    "\n",
    "train_dataloader = DataLoader(\n",
    "    train_dataset,\n",
    "    shuffle=True,\n",
    "    collate_fn=default_data_collator,\n",
    "    batch_size=8,\n",
    ")\n",
    "eval_dataloader = DataLoader(\n",
    "    validation_set, collate_fn=default_data_collator, batch_size=8\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cml81PT9QK6-"
   },
   "outputs": [],
   "source": [
    "model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-i_E-Qr9QK6-"
   },
   "outputs": [],
   "source": [
    "from torch.optim import AdamW\n",
    "\n",
    "optimizer = AdamW(model.parameters(), lr=2e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "b0jtsA_BQK6-"
   },
   "outputs": [],
   "source": [
    "from accelerate import Accelerator\n",
    "\n",
    "accelerator = Accelerator(mixed_precision=\"fp16\")\n",
    "model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(\n",
    "    model, optimizer, train_dataloader, eval_dataloader\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "88PpNpYfQK6_"
   },
   "outputs": [],
   "source": [
    "from transformers import get_scheduler\n",
    "\n",
    "num_train_epochs = 3\n",
    "num_update_steps_per_epoch = len(train_dataloader)\n",
    "num_training_steps = num_train_epochs * num_update_steps_per_epoch\n",
    "\n",
    "lr_scheduler = get_scheduler(\n",
    "    \"linear\",\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=0,\n",
    "    num_training_steps=num_training_steps,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9ctafoG6QK6_"
   },
   "outputs": [],
   "source": [
    "from huggingface_hub import Repository, get_full_repo_name\n",
    "\n",
    "model_name = \"camembert-base-finetuned-piaf-accelerate\"\n",
    "repo_name = get_full_repo_name(model_name)\n",
    "repo_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GkEGULCrQK6_"
   },
   "outputs": [],
   "source": [
    "output_dir = \"camembert-base-finetuned-piaf-accelerate\"\n",
    "repo = Repository(output_dir, clone_from=repo_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-t6DsGmsQK7A"
   },
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "import torch\n",
    "\n",
    "progress_bar = tqdm(range(num_training_steps))\n",
    "\n",
    "for epoch in range(num_train_epochs):\n",
    "    # Training\n",
    "    model.train()\n",
    "    for step, batch in enumerate(train_dataloader):\n",
    "        outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        accelerator.backward(loss)\n",
    "\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "        progress_bar.update(1)\n",
    "\n",
    "    # Evaluation\n",
    "    model.eval()\n",
    "    start_logits = []\n",
    "    end_logits = []\n",
    "    accelerator.print(\"Evaluation!\")\n",
    "    for batch in tqdm(eval_dataloader):\n",
    "        with torch.no_grad():\n",
    "            outputs = model(**batch)\n",
    "\n",
    "        start_logits.append(accelerator.gather(outputs.start_logits).cpu().numpy())\n",
    "        end_logits.append(accelerator.gather(outputs.end_logits).cpu().numpy())\n",
    "\n",
    "    start_logits = np.concatenate(start_logits)\n",
    "    end_logits = np.concatenate(end_logits)\n",
    "    start_logits = start_logits[: len(validation_dataset)]\n",
    "    end_logits = end_logits[: len(validation_dataset)]\n",
    "\n",
    "    metrics = compute_metrics(\n",
    "        start_logits, end_logits, validation_dataset, raw_datasets[\"test\"]\n",
    "    )\n",
    "    print(f\"epoch {epoch}:\", metrics)\n",
    "\n",
    "    # Save and upload\n",
    "    accelerator.wait_for_everyone()\n",
    "    unwrapped_model = accelerator.unwrap_model(model)\n",
    "    unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)\n",
    "    if accelerator.is_main_process:\n",
    "        tokenizer.save_pretrained(output_dir)\n",
    "        repo.push_to_hub(\n",
    "            commit_message=f\"Training in progress epoch {epoch}\", blocking=False\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "h7iQN81gQK7B"
   },
   "outputs": [],
   "source": [
    "accelerator.wait_for_everyone()\n",
    "unwrapped_model = accelerator.unwrap_model(model)\n",
    "unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "u3EKyP6RQK7B"
   },
   "outputs": [],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "# Replace this with your own checkpoint\n",
    "model_checkpoint = \"huggingface-course/camembert-finetuned-piaf\"\n",
    "question_answerer = pipeline(\"question-answering\", model=model_checkpoint)\n",
    "\n",
    "context = \"\"\"\n",
    "🤗 Transformers est soutenu par les trois bibliothèques d'apprentissage profond les plus populaires - Jax, PyTorch et TensorFlow - avec une intégration transparente entre elles. Il est simple d'entraîner vos modèles avec l'une avant de les charger pour l'inférence avec l'autre.\n",
    "\"\"\"\n",
    "question = \"Quelles sont les bibliothèques d'apprentissage profond derrière 🤗 Transformers ?\"\n",
    "question_answerer(question=question, context=context)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
